# -*- coding: utf-8 -*-
"""
Created on Sat Jan 30 12:46:02 2021

@author: wzhangcd
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import math
import numpy as np
import pandas as pd

import sys
sys.path.append('./lib/')
from pkl_process import *
from utils import load_graphdata_channel_my, compute_val_loss_sttn

from time import time
import shutil
import argparse
import configparser
from tensorboardX import SummaryWriter
import os

from ST_Transformer_new import STTransformer # STTN model with linear layer to get positional embedding
from ST_Transformer_new_sinembedding import STTransformer_sinembedding # STTN model with sin()/cos() to get positional embedding, the same as in "Attention Is All You Need"

#%%

if __name__ == '__main__':
    # Training entry point for the STTN model: load the preprocessed data and
    # the adjacency matrix, build the network, then for every epoch validate
    # (checkpointing the best weights so far) and train for one pass over the
    # training set. Losses are logged to TensorBoard under params_path.

    ### Paths and data-preparation settings
    params_path = './Experiment/PEMS25_embed_size64'  ## Path for saving network parameters
    print('params_path:', params_path)
    filename = './PEMSD7/V_25_r1_d0_w0_astcgn.npz'  ## Data generated by prepareData.py
    num_of_hours, num_of_days, num_of_weeks = 1, 0, 0  ## The same setting as prepareData.py

    ### Training hyperparameters
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    DEVICE = device
    batch_size = 32
    learning_rate = 0.01
    epochs = 1

    ### Generate data loaders
    # Only train_loader and val_loader are used below; the remaining outputs
    # (test split, target tensors, _mean/_std) are unused by this script.
    train_loader, train_target_tensor, val_loader, val_target_tensor, test_loader, test_target_tensor, _mean, _std = load_graphdata_channel_my(
        filename, num_of_hours, num_of_days, num_of_weeks, DEVICE, batch_size)

    ### Adjacency matrix import (CSV without a header row)
    adj_mx = np.array(pd.read_csv('./PEMSD7/W_25.csv', header=None))
    A = torch.Tensor(adj_mx)

    ### Model hyperparameters
    in_channels = 1  # Channels of input
    embed_size = 64  # Dimension of hidden embedding features
    time_num = 288  # passed to STTransformer; presumably time slots per day (288 x 5 min) — TODO confirm
    num_layers = 3  # Number of ST Block
    T_dim = 12  # Input length, should be the same as prepareData.py
    output_T_dim = 12  # Output Expected length
    heads = 2  # Number of Heads in MultiHeadAttention
    cheb_K = 2  # Order for Chebyshev Polynomials (Eq 2)
    forward_expansion = 4  # Dimension of Feed Forward Network: embed_size --> embed_size * forward_expansion --> embed_size
    dropout = 0

    ### Construct network
    net = STTransformer(
        A,
        in_channels,
        embed_size,
        time_num,
        num_layers,
        T_dim,
        output_T_dim,
        heads,
        cheb_K,
        forward_expansion,
        dropout)

    # Alternative: sin()/cos() positional embedding instead of a learned one.
    # net = STTransformer_sinembedding(
    #     A,
    #     in_channels,
    #     embed_size,
    #     time_num,
    #     num_layers,
    #     T_dim,
    #     output_T_dim,
    #     heads,
    #     cheb_K,
    #     forward_expansion, dropout)

    net.to(device)

    ### Parameter directory handling
    # start_epoch == 0: start from scratch (an existing params_path is WIPED).
    # start_epoch > 0: resume from params_path/epoch_<start_epoch>.params.
    start_epoch = 0
    if (start_epoch == 0) and (not os.path.exists(params_path)):
        os.makedirs(params_path)
        print('create params directory %s' % (params_path))
    elif (start_epoch == 0) and (os.path.exists(params_path)):
        shutil.rmtree(params_path)
        os.makedirs(params_path)
        print('delete the old one and create params directory %s' % (params_path))
    elif (start_epoch > 0) and (os.path.exists(params_path)):
        print('train from params directory %s' % (params_path))
    else:
        # Reached for a negative start_epoch, or when resuming without a params directory.
        raise SystemExit('Invalid start_epoch=%s or missing params directory %s'
                         % (start_epoch, params_path))

    #### Loss function and optimizer
    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=learning_rate)
    # criterion = nn.L1Loss().to(device)

    #### Training log setup; print network and optimizer
    sw = SummaryWriter(logdir=params_path, flush_secs=5)
    print(net)
    print('Optimizer\'s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name, '\t', optimizer.state_dict()[var_name])

    global_step = 0
    best_epoch = 0
    best_val_loss = np.inf
    start_time = time()

    #### Load parameters from file when resuming
    if start_epoch > 0:
        params_filename = os.path.join(params_path, 'epoch_%s.params' % start_epoch)
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        net.load_state_dict(torch.load(params_filename, map_location=DEVICE))
        print('start epoch:', start_epoch)
        print('load weight from: ', params_filename)

    #### Train model
    # Iteration `epoch` first validates the weights entering that epoch and
    # saves them as epoch_<epoch>.params if they are the best so far, then
    # trains for one epoch. The extra last iteration (epoch == epochs) only
    # validates, so the weights produced by the final training epoch are
    # also evaluated and can be checkpointed.
    for epoch in range(start_epoch, epochs + 1):
        ##### Evaluate on validation set and keep the best parameters
        params_filename = os.path.join(params_path, 'epoch_%s.params' % epoch)
        val_loss = compute_val_loss_sttn(net, val_loader, criterion, sw, epoch)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(net.state_dict(), params_filename)
            print('save parameters to file: %s' % params_filename)

        if epoch == epochs:
            break  # final validation-only pass; no further training

        net.train()  # ensure dropout layers are in train mode
        for encoder_inputs, labels in train_loader:
            optimizer.zero_grad()
            # Swaps input axes 1 and 2 before the forward pass; presumably into
            # the layout STTransformer expects — TODO confirm against the model.
            outputs = net(encoder_inputs.permute(0, 2, 1, 3))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            training_loss = loss.item()
            global_step += 1
            sw.add_scalar('training_loss', training_loss, global_step)
            if global_step % 1000 == 0:
                print('global step: %s, training loss: %.2f, time: %.2fs' % (global_step, training_loss, time() - start_time))

    print('best epoch:', best_epoch)
    # Flush pending TensorBoard events and release the log file.
    sw.close()